变分推断通过近似后验分布,将复杂的概率推断转化为优化问题,是贝叶斯深度学习和生成模型的核心技术。

核心问题

贝叶斯推断中,后验分布 的计算涉及难以处理的边际似然:

分母 (证据/边际似然)需要对所有隐变量积分,在高维或复杂模型中不可计算。

变分推断思想

寻找简单的近似分布 来逼近真实后验 ,最小化 KL 散度:

证据下界(ELBO)

两种等价形式

似然与先验

证据与熵

组成部分作用
重构项鼓励模型解释数据
KL 散度/熵保持分布简约,防止过拟合

VAE 架构

组件功能
编码器输入 ,输出 $q(Z
解码器输入 ,输出重构的

重参数化技巧

将随机性转移到 ,使梯度可反向传播。

应用场景

场景说明
图像/文本生成从隐空间采样生成新数据
贝叶斯神经网络权重不确定性量化
异常检测异常数据重构损失大
主题模型LDA 的快速推断

前沿方向

  • 归一化流:构建更灵活的后验近似
  • 扩散模型联系:与层级 VAE 的数学关联
  • 后验坍塌:KL 退火等解决方案

PyTorch 实现

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class VAE(nn.Module):
    def __init__(self, latent_dim=20, hidden_dim=400, image_size=784):
        super().__init__()
        self.fc1 = nn.Linear(image_size, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, image_size)
 
    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc_mu(h), self.fc_logvar(h)
 
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
 
    def decode(self, z):
        return torch.sigmoid(self.fc4(F.relu(self.fc3(z))))
 
    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
 
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

推荐资源